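'''
Convert single-branch pretrained checkpoints (ResNet-50, Darknet-53, VGG16)
into "two-way" checkpoints in which every pretrained weight is duplicated
under two branch-specific key names.

Author: SlytherinGe
LastEditTime: 2021-04-01 16:02:55
'''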
import torch

# resnet50
# if __name__ == '__main__':
    
#     ori_model_file = torch.load('/home/gejunyao/.cache/torch/hub/checkpoints/backup/resnet50-19c8e357.pth')
#     out_model_file = torch.load('/media/gejunyao/Disk/Gejunyao/develop/temp/temp.pth')
#     for key, value in ori_model_file.items():
#         if key == 'conv1.weight':
#             out_model_file['stem1.0.weight'] = value
#             out_model_file['stem2.0.weight'] = value
#         elif key == 'bn1.running_mean':
#             out_model_file['stem1.1.running_mean'] = value
#             out_model_file['stem2.1.running_mean'] = value
#         elif key == 'bn1.running_var':
#             out_model_file['stem1.1.running_var'] = value
#             out_model_file['stem2.1.running_var'] = value
#         elif key == 'bn1.weight':
#             out_model_file['stem1.1.weight'] = value
#             out_model_file['stem2.1.weight'] = value
#         elif key == 'bn1.bias':
#             out_model_file['stem1.1.bias'] = value
#             out_model_file['stem2.1.bias'] = value
#         else:
#             corresponding_out_key_1 = ''
#             corresponding_out_key_2 = ''
#             key_component = key.split('.')
#             corresponding_out_key_1 = key_component[0]+'_1'
#             corresponding_out_key_2 = key_component[0]+'_2'
#             for i in range(1, len(key_component)):
#                 corresponding_out_key_1 += ('.'+key_component[i])
#                 corresponding_out_key_2 += ('.'+key_component[i])
#             out_model_file[corresponding_out_key_1] = value
#             out_model_file[corresponding_out_key_2] = value
#     torch.save(out_model_file, '/media/gejunyao/Disk/Gejunyao/develop/pretrained_model/resnet50-twoway.pth')

#darknet53
# if __name__ == '__main__':
    
#     ori_model_file = torch.load('/home/gejunyao/.cache/torch/hub/checkpoints/backup/darknet53-a628ea1b.pth')
#     out_model_file = torch.load('/media/gejunyao/Disk/Gejunyao/develop/temp/temp.pth')
#     for key, value in ori_model_file['state_dict'].items():
#         if key[:5] == 'conv1':
#             out_model_file[key] = value
#             new_key = key[:4] + '2' + key[5:]
#             out_model_file[new_key] = value
#         else:
#             key1 = key[:4] + '1' + key[4:]
#             key2 = key[:4] + '2' + key[4:]
#             out_model_file[key1] = value
#             out_model_file[key2] = value
#     torch.save(out_model_file, '/media/gejunyao/Disk/Gejunyao/develop/pretrained_model/darknet53-twoway.pth')

# VGG16
def _branch_key(key, branch):
    """Return *key* with *branch* appended to its first dotted component.

    ``'features.0.weight'`` with branch ``'1'`` becomes ``'features1.0.weight'``;
    a key without a dot simply gets *branch* appended (``'foo'`` -> ``'foo1'``).
    Unlike a fixed slice index, this works for any top-level module name
    (e.g. ``'classifier.0.weight'`` -> ``'classifier1.0.weight'``).
    """
    head, sep, tail = key.partition('.')
    return head + branch + sep + tail


def convert_to_twoway(ori_ckpt, out_state, branches=('1', '2')):
    """Copy every weight of *ori_ckpt* into each branch of *out_state*.

    Args:
        ori_ckpt: A loaded checkpoint mapping. If it wraps its weights under a
            ``'state_dict'`` key, that inner mapping is used; otherwise the
            checkpoint itself is treated as the state dict.
        out_state: Target state dict; updated in place. For every source key,
            one entry per branch suffix is written (existing entries with the
            same name are overwritten).
        branches: Suffixes inserted after the top-level module name of each
            key, in write order. Defaults to ``('1', '2')``.

    Returns:
        *out_state*, for convenience.
    """
    # Some checkpoints store their weights under 'state_dict' next to metadata.
    ori_state = ori_ckpt.get('state_dict', ori_ckpt)
    for key, value in ori_state.items():
        for branch in branches:
            out_state[_branch_key(key, branch)] = value
    return out_state


if __name__ == '__main__':
    # Paths are specific to the original author's machine; adjust as needed.
    ori_ckpt_path = '/home/gejunyao/.cache/torch/hub/checkpoints/backup/vgg16_caffe-292e1171.pth'
    template_path = '/media/gejunyao/Disk/Gejunyao/develop/temp/temp.pth'
    out_path = '/media/gejunyao/Disk/Gejunyao/develop/pretrained_model/vgg16-twoway.pth'

    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only hosts.
    ori_model_file = torch.load(ori_ckpt_path, map_location='cpu')
    # NOTE(review): presumably a state dict saved from the two-way model; its
    # matching entries are overwritten with the duplicated pretrained weights.
    out_model_file = torch.load(template_path, map_location='cpu')
    convert_to_twoway(ori_model_file, out_model_file)
    torch.save(out_model_file, out_path)
